{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This is a companion notebook for the book [Deep Learning with Python, Third Edition](https://www.manning.com/books/deep-learning-with-python-third-edition). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.\n\n**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**\n\nThe book's contents are available online at [deeplearningwithpython.io](https://deeplearningwithpython.io).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!pip install keras keras-hub --upgrade -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# @title\n",
    "import os\n",
    "from IPython.core.magic import register_cell_magic\n",
    "\n",
    "@register_cell_magic\n",
    "def backend(line, cell):\n",
    "    \"\"\"Run `cell` only if the active Keras backend matches the requested one.\n",
    "\n",
    "    The requested backend is the last whitespace-separated word of `line`;\n",
    "    it is compared with the KERAS_BACKEND environment variable. On a match\n",
    "    the cell body is executed, otherwise a hint on how to switch is printed.\n",
    "    \"\"\"\n",
    "    words = line.split()\n",
    "    if not words:\n",
    "        # A bare `%%backend` used to fail with an IndexError; explain instead.\n",
    "        print(\"Usage: %%backend <backend name>, for example: %%backend jax\")\n",
    "        return\n",
    "    required = words[-1]\n",
    "    if os.environ.get(\"KERAS_BACKEND\", \"\") != required:\n",
    "        print(\n",
    "            f\"This cell requires the {required} backend. To run it, change KERAS_BACKEND to \"\n",
    "            f\"\\\"{required}\\\" at the top of the notebook, restart the runtime, and rerun the notebook.\"\n",
    "        )\n",
    "        return\n",
    "    get_ipython().run_cell(cell)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Text generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### A brief history of sequence generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Training a mini-GPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Free up more GPU memory on the Jax and TensorFlow backends.\n",
    "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "import pathlib\n",
    "\n",
    "extract_dir = keras.utils.get_file(\n",
    "    fname=\"mini-c4\",\n",
    "    origin=(\n",
    "        \"https://hf.co/datasets/mattdangerw/mini-c4/resolve/main/mini-c4.zip\"\n",
    "    ),\n",
    "    extract=True,\n",
    ")\n",
    "extract_dir = pathlib.Path(extract_dir) / \"mini-c4\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Read the shard as UTF-8 explicitly so the result does not depend on the\n",
    "# platform's default text encoding.\n",
    "with open(extract_dir / \"shard0.txt\", \"r\", encoding=\"utf-8\") as f:\n",
    "    # Turn escaped newlines (backslash + \"n\") into real ones and show the first 100 characters.\n",
    "    print(f.readline().replace(\"\\\\n\", \"\\n\")[:100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras_hub\n",
    "import numpy as np\n",
    "\n",
    "vocabulary_file = keras.utils.get_file(\n",
    "    origin=\"https://hf.co/mattdangerw/spiece/resolve/main/vocabulary.proto\",\n",
    ")\n",
    "tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(vocabulary_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "tokenizer.tokenize(\"The quick brown fox.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "tokenizer.detokenize([450, 4996, 17354, 1701, 29916, 29889])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "batch_size = 64\n",
    "sequence_length = 256\n",
    "suffix = np.array([tokenizer.token_to_id(\"<|endoftext|>\")])\n",
    "\n",
    "def read_file(filename):\n",
    "    \"\"\"Build a dataset of token-id sequences from one text shard.\n",
    "\n",
    "    Every line is read as one document: escaped newlines (a backslash followed\n",
    "    by the letter n) become real newlines, the text is tokenized, and `suffix`\n",
    "    is appended to the resulting ids.\n",
    "    \"\"\"\n",
    "    documents = tf.data.TextLineDataset(filename).map(\n",
    "        lambda line: tf.strings.regex_replace(line, r\"\\\\n\", \"\\n\")\n",
    "    )\n",
    "    token_ids = documents.map(tokenizer, num_parallel_calls=8)\n",
    "    return token_ids.map(lambda ids: tf.concat([ids, suffix], -1))\n",
    "\n",
    "files = [str(file) for file in extract_dir.glob(\"*.txt\")]\n",
    "ds = tf.data.Dataset.from_tensor_slices(files)\n",
    "ds = ds.interleave(read_file, cycle_length=32, num_parallel_calls=32)\n",
    "ds = ds.rebatch(sequence_length + 1, drop_remainder=True)\n",
    "ds = ds.map(lambda x: (x[:-1], x[1:]))\n",
    "ds = ds.batch(batch_size).prefetch(8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "num_batches = 58746\n",
    "num_val_batches = 500\n",
    "num_train_batches = num_batches - num_val_batches\n",
    "val_ds = ds.take(num_val_batches).repeat()\n",
    "train_ds = ds.skip(num_val_batches).repeat()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Building the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import layers\n",
    "\n",
    "class TransformerDecoder(keras.Layer):\n",
    "    \"\"\"Transformer block made of causal self-attention and a feed-forward network.\n",
    "\n",
    "    The output of each of the two sublayers goes through dropout, is added to\n",
    "    the sublayer's input, and is then layer-normalized.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_dim, intermediate_dim, num_heads):\n",
    "        super().__init__()\n",
    "        # The hidden size is split evenly across the attention heads.\n",
    "        self.self_attention = layers.MultiHeadAttention(\n",
    "            num_heads=num_heads,\n",
    "            key_dim=hidden_dim // num_heads,\n",
    "            dropout=0.1,\n",
    "        )\n",
    "        self.self_attention_layernorm = layers.LayerNormalization()\n",
    "        self.feed_forward_1 = layers.Dense(intermediate_dim, activation=\"relu\")\n",
    "        self.feed_forward_2 = layers.Dense(hidden_dim)\n",
    "        self.feed_forward_layernorm = layers.LayerNormalization()\n",
    "        # One dropout layer is reused after both sublayers.\n",
    "        self.dropout = layers.Dropout(0.1)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        # Attention sublayer, called with a causal mask.\n",
    "        attended = self.self_attention(\n",
    "            query=inputs, key=inputs, value=inputs, use_causal_mask=True\n",
    "        )\n",
    "        attended = self.self_attention_layernorm(self.dropout(attended) + inputs)\n",
    "        # Feed-forward sublayer, wrapped in the same dropout / residual / norm pattern.\n",
    "        projected = self.feed_forward_2(self.feed_forward_1(attended))\n",
    "        return self.feed_forward_layernorm(self.dropout(projected) + attended)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import ops\n",
    "\n",
    "class PositionalEmbedding(keras.Layer):\n",
    "    \"\"\"Token embedding plus position embedding, with a reverse projection.\n",
    "\n",
    "    Called normally, maps token ids to the sum of their token and position\n",
    "    embeddings. Called with `reverse=True`, multiplies the input by the\n",
    "    transposed token-embedding matrix instead.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, sequence_length, input_dim, output_dim):\n",
    "        super().__init__()\n",
    "        self.token_embeddings = layers.Embedding(\n",
    "            input_dim=input_dim, output_dim=output_dim\n",
    "        )\n",
    "        self.position_embeddings = layers.Embedding(\n",
    "            input_dim=sequence_length, output_dim=output_dim\n",
    "        )\n",
    "\n",
    "    def call(self, inputs, reverse=False):\n",
    "        if not reverse:\n",
    "            # 0, 1, 2, ... along the last axis, with the same shape as `inputs`.\n",
    "            positions = ops.cumsum(ops.ones_like(inputs), axis=-1) - 1\n",
    "            token_part = self.token_embeddings(inputs)\n",
    "            position_part = self.position_embeddings(positions)\n",
    "            return token_part + position_part\n",
    "        # Reverse mode reuses the token-embedding weights as the output projection.\n",
    "        embedding_matrix = self.token_embeddings.embeddings\n",
    "        return ops.matmul(inputs, ops.transpose(embedding_matrix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "keras.config.set_dtype_policy(\"mixed_float16\")\n",
    "\n",
    "vocab_size = tokenizer.vocabulary_size()\n",
    "hidden_dim = 512\n",
    "intermediate_dim = 2056\n",
    "num_heads = 8\n",
    "num_layers = 8\n",
    "\n",
    "inputs = keras.Input(shape=(None,), dtype=\"int32\", name=\"inputs\")\n",
    "embedding = PositionalEmbedding(sequence_length, vocab_size, hidden_dim)\n",
    "x = embedding(inputs)\n",
    "x = layers.LayerNormalization()(x)\n",
    "for i in range(num_layers):\n",
    "    x = TransformerDecoder(hidden_dim, intermediate_dim, num_heads)(x)\n",
    "outputs = embedding(x, reverse=True)\n",
    "mini_gpt = keras.Model(inputs, outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Pretraining the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class WarmupSchedule(keras.optimizers.schedules.LearningRateSchedule):\n",
    "    \"\"\"Learning rate that ramps up linearly from zero and then stays constant.\n",
    "\n",
    "    Args:\n",
    "        rate: Learning rate reached at the end of warmup. Defaults to 2e-4.\n",
    "        warmup_steps: Number of steps over which the rate is ramped up.\n",
    "            Defaults to 1,000.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rate=2e-4, warmup_steps=1_000.0):\n",
    "        self.rate = rate\n",
    "        self.warmup_steps = warmup_steps\n",
    "\n",
    "    def __call__(self, step):\n",
    "        step = ops.cast(step, dtype=\"float32\")\n",
    "        # Fraction of warmup completed, capped at 1 once warmup is over.\n",
    "        scale = ops.minimum(step / self.warmup_steps, 1.0)\n",
    "        return self.rate * scale\n",
    "\n",
    "    def get_config(self):\n",
    "        # Lets Keras serialize an optimizer that uses this schedule.\n",
    "        return {\"rate\": self.rate, \"warmup_steps\": self.warmup_steps}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "schedule = WarmupSchedule()\n",
    "x = range(0, 5_000, 100)\n",
    "y = [ops.convert_to_numpy(schedule(step)) for step in x]\n",
    "plt.plot(x, y)\n",
    "plt.xlabel(\"Train Step\")\n",
    "plt.ylabel(\"Learning Rate\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# \u26a0\ufe0fNOTE\u26a0\ufe0f: If you can run the following with a Colab Pro GPU, we suggest you\n",
    "# do so. This fit() call will take many hours on free tier GPUs. You can also\n",
    "# reduce steps_per_epoch to try the code with a less trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "num_epochs = 8\n",
    "steps_per_epoch = num_train_batches // num_epochs\n",
    "validation_steps = num_val_batches\n",
    "\n",
    "mini_gpt.compile(\n",
    "    optimizer=keras.optimizers.Adam(schedule),\n",
    "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "mini_gpt.fit(\n",
    "    train_ds,\n",
    "    validation_data=val_ds,\n",
    "    epochs=num_epochs,\n",
    "    steps_per_epoch=steps_per_epoch,\n",
    "    validation_steps=validation_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Generative decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def generate(prompt, max_length=64):\n",
    "    \"\"\"Greedily extend `prompt` one token at a time up to `max_length` tokens.\n",
    "\n",
    "    The model is called on the whole sequence for every new token, and the\n",
    "    highest-scoring token at the last position is appended.\n",
    "    \"\"\"\n",
    "    tokens = list(ops.convert_to_numpy(tokenizer(prompt)))\n",
    "    num_new_tokens = max_length - len(tokens)\n",
    "    for _ in range(num_new_tokens):\n",
    "        scores = mini_gpt(ops.convert_to_numpy([tokens]))\n",
    "        # Only the output at the final position predicts the next token.\n",
    "        last_scores = ops.convert_to_numpy(scores[0, -1])\n",
    "        tokens.append(np.argmax(last_scores).item())\n",
    "    return tokenizer.detokenize(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "prompt = \"A piece of advice\"\n",
    "generate(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def compiled_generate(prompt, max_length=64):\n",
    "    \"\"\"Greedy decoding on a fixed-length, zero-padded token buffer.\n",
    "\n",
    "    The prompt tokens are padded with zeros up to `max_length`, so the input\n",
    "    passed to `predict()` keeps the same length on every step; each padded\n",
    "    slot is then filled with the highest-scoring token predicted at the\n",
    "    previous position.\n",
    "    \"\"\"\n",
    "    prompt_tokens = list(ops.convert_to_numpy(tokenizer(prompt)))\n",
    "    prompt_length = len(prompt_tokens)\n",
    "    padding = [0] * (max_length - prompt_length)\n",
    "    tokens = prompt_tokens + padding\n",
    "    for position in range(prompt_length, max_length):\n",
    "        scores = mini_gpt.predict(np.array([tokens]), verbose=0)\n",
    "        # The output at `position - 1` is the prediction for the token at `position`.\n",
    "        tokens[position] = np.argmax(scores[0, position - 1]).item()\n",
    "    return tokenizer.detokenize(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import timeit\n",
    "tries = 10\n",
    "timeit.timeit(lambda: compiled_generate(prompt), number=tries) / tries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Sampling strategies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def compiled_generate(prompt, sample_fn, max_length=64):\n",
    "    \"\"\"Decode on a fixed-length, zero-padded buffer using a pluggable sampler.\n",
    "\n",
    "    `sample_fn` receives the model output for the previous position and\n",
    "    returns the id of the token to write into the current slot.\n",
    "    \"\"\"\n",
    "    prompt_tokens = list(ops.convert_to_numpy(tokenizer(prompt)))\n",
    "    prompt_length = len(prompt_tokens)\n",
    "    padding = [0] * (max_length - prompt_length)\n",
    "    tokens = prompt_tokens + padding\n",
    "    for position in range(prompt_length, max_length):\n",
    "        scores = mini_gpt.predict(np.array([tokens]), verbose=0)\n",
    "        sampled = sample_fn(scores[0, position - 1])\n",
    "        # Convert the sampler's result to a plain Python int before storing it.\n",
    "        tokens[position] = np.array(ops.convert_to_numpy(sampled)).item()\n",
    "    return tokenizer.detokenize(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def greedy_search(preds):\n",
    "    \"\"\"Return the index of the highest-scoring entry in `preds`.\"\"\"\n",
    "    best_index = ops.argmax(preds)\n",
    "    return best_index\n",
    "\n",
    "compiled_generate(prompt, greedy_search)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def random_sample(preds, temperature=1.0):\n",
    "    \"\"\"Sample one token id from `preds` after dividing them by `temperature`.\"\"\"\n",
    "    scaled = preds / temperature\n",
    "    # Add a leading axis for the categorical() call, then strip it from the result.\n",
    "    batched = scaled[None, :]\n",
    "    return keras.random.categorical(batched, num_samples=1)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "compiled_generate(prompt, random_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "compiled_generate(prompt, partial(random_sample, temperature=2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "compiled_generate(prompt, partial(random_sample, temperature=0.8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "compiled_generate(prompt, partial(random_sample, temperature=0.2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def top_k(preds, k=5, temperature=1.0):\n",
    "    \"\"\"Sample a token id from the `k` highest-scoring entries of `preds`.\n",
    "\n",
    "    Scores are divided by `temperature` before the top `k` are selected and\n",
    "    sampled from.\n",
    "    \"\"\"\n",
    "    scaled = preds / temperature\n",
    "    candidate_scores, candidate_ids = ops.top_k(scaled, k=k, sorted=False)\n",
    "    # Sample a slot among the k candidates, then map it back to its original index.\n",
    "    slot = keras.random.categorical(candidate_scores[None, :], num_samples=1)[0]\n",
    "    return ops.take_along_axis(candidate_ids, slot, axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "compiled_generate(prompt, partial(top_k, k=5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "compiled_generate(prompt, partial(top_k, k=20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "compiled_generate(prompt, partial(top_k, k=5, temperature=0.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Using a pretrained LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Text generation with the Gemma model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import kagglehub\n",
    "\n",
    "kagglehub.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm = keras_hub.models.CausalLM.from_preset(\n",
    "    \"gemma3_1b\",\n",
    "    dtype=\"float32\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.compile(sampler=\"greedy\")\n",
    "gemma_lm.generate(\"A piece of advice\", max_length=40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(\"How can I make brownies?\", max_length=40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(\n",
    "    \"The following brownie recipe is easy to make in just a few \"\n",
    "    \"steps.\\n\\nYou can start by\",\n",
    "    max_length=40,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(\n",
    "    \"Tell me about the 542nd president of the United States.\",\n",
    "    max_length=40,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Instruction fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Plain-text markers that wrap the instruction and the response. The prompt\n",
    "# template must start directly with the marker (no stray quote character) so\n",
    "# that training prompts match the format used when prompting the model later.\n",
    "PROMPT_TEMPLATE = \"\"\"[instruction]\\n{}[end]\\n[response]\\n\"\"\"\n",
    "RESPONSE_TEMPLATE = \"\"\"{}[end]\"\"\"\n",
    "\n",
    "dataset_path = keras.utils.get_file(\n",
    "    origin=(\n",
    "        \"https://hf.co/datasets/databricks/databricks-dolly-15k/\"\n",
    "        \"resolve/main/databricks-dolly-15k.jsonl\"\n",
    "    ),\n",
    ")\n",
    "data = {\"prompts\": [], \"responses\": []}\n",
    "# The file holds one JSON object per line; read it as UTF-8 regardless of the\n",
    "# platform's default encoding.\n",
    "with open(dataset_path, encoding=\"utf-8\") as file:\n",
    "    for line in file:\n",
    "        features = json.loads(line)\n",
    "        # Skip examples that come with a non-empty context field.\n",
    "        if features[\"context\"]:\n",
    "            continue\n",
    "        data[\"prompts\"].append(PROMPT_TEMPLATE.format(features[\"instruction\"]))\n",
    "        data[\"responses\"].append(RESPONSE_TEMPLATE.format(features[\"response\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "data[\"prompts\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "data[\"responses\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Keep the shuffle order fixed across iterations. By default `shuffle` draws a\n",
    "# new order every time the dataset is iterated, so `take(100)` and `skip(100)`\n",
    "# would be cut from different orderings and validation examples could also\n",
    "# show up in the training split.\n",
    "ds = (\n",
    "    tf.data.Dataset.from_tensor_slices(data)\n",
    "    .shuffle(2000, seed=42, reshuffle_each_iteration=False)\n",
    "    .batch(2)\n",
    ")\n",
    "val_ds = ds.take(100)\n",
    "train_ds = ds.skip(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "preprocessor = gemma_lm.preprocessor\n",
    "preprocessor.sequence_length = 512\n",
    "batch = next(iter(train_ds))\n",
    "x, y, sample_weight = preprocessor(batch)\n",
    "x[\"token_ids\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x[\"padding_mask\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "sample_weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x[\"token_ids\"][0, :5], y[0, :5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Low-Rank Adaptation (LoRA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.backbone.enable_lora(rank=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.compile(\n",
    "    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer=keras.optimizers.Adam(5e-5),\n",
    "    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n",
    ")\n",
    "gemma_lm.fit(train_ds, validation_data=val_ds, epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(\n",
    "    \"[instruction]\\nHow can I make brownies?[end]\\n\"\n",
    "    \"[response]\\n\",\n",
    "    max_length=512,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(\n",
    "    \"[instruction]\\nWhat is a proper noun?[end]\\n\"\n",
    "    \"[response]\\n\",\n",
    "    max_length=512,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(\n",
    "    \"[instruction]\\nWho is the 542nd president of the United States?[end]\\n\"\n",
    "    \"[response]\\n\",\n",
    "    max_length=512,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Going further with LLMs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Reinforcement Learning with Human Feedback (RLHF)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Using a chatbot trained with RLHF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# \u26a0\ufe0fNOTE\u26a0\ufe0f: If you are running on the free tier Colab GPUs, you will need to\n",
    "# restart your runtime and run the notebook from here to free up memory for\n",
    "# this 4 billion parameter model.\n",
    "import os\n",
    "\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
    "# Free up more GPU memory on the Jax and TensorFlow backends.\n",
    "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n",
    "\n",
    "import keras\n",
    "import keras_hub\n",
    "import kagglehub\n",
    "import numpy as np\n",
    "\n",
    "kagglehub.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm = keras_hub.models.CausalLM.from_preset(\n",
    "    \"gemma3_instruct_4b\",\n",
    "    dtype=\"bfloat16\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "PROMPT_TEMPLATE = \"\"\"<start_of_turn>user\n",
    "{}<end_of_turn>\n",
    "<start_of_turn>model\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "prompt = \"Why can't you assign values in Jax tensors? Be brief!\"\n",
    "gemma_lm.generate(PROMPT_TEMPLATE.format(prompt), max_length=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "prompt = \"Who is the 542nd president of the United States?\"\n",
    "gemma_lm.generate(PROMPT_TEMPLATE.format(prompt), max_length=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Multimodal LLMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "image_url = (\n",
    "    \"https://github.com/mattdangerw/keras-nlp-scripts/\"\n",
    "    \"blob/main/learned-python.png?raw=true\"\n",
    ")\n",
    "image_path = keras.utils.get_file(origin=image_url)\n",
    "\n",
    "image = np.array(keras.utils.load_img(image_path))\n",
    "plt.axis(\"off\")\n",
    "plt.imshow(image)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.preprocessor.max_images_per_prompt = 1\n",
    "gemma_lm.preprocessor.sequence_length = 512\n",
    "prompt = \"What is going on in this image? Be concise!<start_of_image>\"\n",
    "gemma_lm.generate({\n",
    "    \"prompts\": PROMPT_TEMPLATE.format(prompt),\n",
    "    \"images\": [image],\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "prompt = \"What is the snake wearing?<start_of_image>\"\n",
    "gemma_lm.generate({\n",
    "    \"prompts\": PROMPT_TEMPLATE.format(prompt),\n",
    "    \"images\": [image],\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "##### Foundation models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Retrieval Augmented Generation (RAG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### \"Reasoning\" models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "prompt = \"\"\"Judy wrote a 2-page letter to 3 friends twice a week for 3 months.\n",
    "How many letters did she write?\n",
    "Be brief, and add \"ANSWER:\" before your final answer.\"\"\"\n",
    "\n",
    "gemma_lm.compile(sampler=\"random\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(PROMPT_TEMPLATE.format(prompt))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "gemma_lm.generate(PROMPT_TEMPLATE.format(prompt))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Where are LLMs heading next?"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "chapter16_text-generation",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}